Add Metal DLPack zero-copy sharing #3531
|
Hi @XXXXRT666, I read through this PR after @awni redirected us here from #3548. Wanted to offer some testing help that complements the PyTorch MPS bring-up you have: we maintain a downstream TileLang fork (https://github.com/DatasunriseOU/tilelang) whose TVM-FFI bridge exports
If useful, once the PR converges I can:
Tag me here when you'd like input — no rush, just don't want this to slip past once it's review-ready. (For the orthogonal |
Required by ml-explore/mlx PR ml-explore#3531 (Metal DLPack zero-copy sharing). SHA: 33f52e635db5e6229060481d16a167230a1a474b PR: wjakob/nanobind#1338 Branch: metal-dlpack-cast
Force-pushed from 002360f to 4e16f1d.
|
This would be super cool if it landed for end to end "0-copy" support in safetensors! I'm working (safetensors/safetensors#767) on adding reading bytes from disk in raw Also, support for Quick question on the |
https://dmlc.github.io/dlpack/latest/c_api.html#c.DLTensor.data The data pointer points to the allocated data. This will be CUDA device pointer, |
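How a consumer interprets `data` depends on the device reported alongside the tensor, which DLPack expresses as a `(device_type, device_id)` pair. The sketch below lists a few device-type codes from `dlpack.h` and decodes such a pair. The enum is deliberately partial, and `describe_device` is a hypothetical helper for illustration, not part of MLX, PyTorch, or DLPack.

```python
from enum import IntEnum


class DLDeviceType(IntEnum):
    """Partial list of DLPack device-type codes (values from dlpack.h)."""

    kDLCPU = 1
    kDLCUDA = 2
    kDLMetal = 8


def describe_device(dl_device: tuple) -> str:
    """Decode a (device_type, device_id) pair into a readable label.

    Hypothetical helper; raises ValueError for codes not listed above.
    """
    device_type, device_id = dl_device
    return f"{DLDeviceType(device_type).name}:{device_id}"
```

Under these assumptions, `describe_device((1, 0))` names the CPU device and `describe_device((8, 0))` names the first Metal device.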
Force-pushed from 4e16f1d to a17cd99.
|
One API question: should That would match the mental model used by NumPy/PyTorch more closely: |
I think this is very good design. |
mlx 0.31.2
mlx 0.32.0.dev20260523+4e8decde9
benchmark function

import mlx.core as mx
import torch


# --- mlx -> torch candidates ---------------------------------------------------
def mlx_to_torch_current(arr: mx.array, device: torch.device) -> torch.Tensor:
    arr = mx.contiguous(arr)
    mx.eval(arr)
    buf = memoryview(arr)
    dtype_map = {
        mx.float32: torch.float32,
        mx.float16: torch.float16,
        mx.bfloat16: torch.bfloat16,
    }
    t = torch.frombuffer(buf, dtype=dtype_map[arr.dtype]).reshape(arr.shape)
    if device.type == "mps":
        t = t.to(device)
    return t


def mlx_to_torch_dlpack_mps(arr: mx.array, device: torch.device) -> torch.Tensor:
    mx.eval(arr)
    t = torch.from_dlpack(arr)
    if device.type == "mps":
        t = t.to(device)
    return t


def mlx_to_torch_dlpack_cpu(arr: mx.array, device: torch.device) -> torch.Tensor:
    """Force a CPU-typed capsule via `dl_device=(kDLCPU, 0)` (Phase 2+).

    Falls back to the no-kwarg form for builds that don't accept it."""
    mx.eval(arr)
    try:
        cap = arr.__dlpack__(dl_device=(1, 0))
    except TypeError:
        # Older builds: zero-arg lambda. Capsule is already kDLCPU there.
        cap = arr.__dlpack__()
    return torch.from_dlpack(cap)


# --- torch -> mlx candidates ---------------------------------------------------
def torch_to_mlx_current(t: torch.Tensor) -> mx.array:
    if t.device.type != "cpu":
        t = t.cpu()
    t = t.detach()
    if t.dtype == torch.bfloat16:
        return mx.array(t)
    return mx.array(t.numpy())


def torch_to_mlx_dlpack(t: torch.Tensor) -> mx.array:
    """Use mx.from_dlpack when the API exists; old MLX falls back to CPU copy."""
    if hasattr(mx, "from_dlpack"):
        return mx.from_dlpack(t)
    return mx.array(t.detach().cpu())
|
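Conversion candidates like these can be compared with a small timing harness. The following is a minimal sketch, assuming a hypothetical helper named `bench` that is not part of the PR. It only measures wall-clock time around a callable, so for GPU work the callable itself would need to force completion (for example with `mx.eval(...)` or `torch.mps.synchronize()`) for the numbers to be meaningful.

```python
import statistics
import time
from typing import Callable


def bench(fn: Callable[[], object], warmup: int = 3, iters: int = 20) -> float:
    """Return the median wall-clock time of fn() in milliseconds.

    Hypothetical helper: runs `warmup` untimed calls first so one-off
    setup costs do not skew the timed samples.
    """
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)
```

A call such as `bench(lambda: some_candidate(arr, device))` would then give one median figure per candidate; the median is used here because it is less sensitive to occasional slow iterations than the mean.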
I updated the PR to follow this design: |
Force-pushed from 0607c24 to a4378cf.
zcbenz left a comment
This basically looks good to me, thanks for the nice work!
|
I just realized that |
Force-pushed from 47326da to 4eaea96.
Force-pushed from 4eaea96 to a7c4a44.
Force-pushed from 210c578 to f733754.
|
CPU imports now copy the underlying storage span and preserve the For Metal:
|
|
The earlier CI failures came from three independent issues. First, the new Second, Third, the macOS Metal validation failure was from a test mutating the original PyTorch MPS tensor after exporting/importing it through DLPack; I think it's a problem with PyTorch

METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -c 'import torch; x=torch.tensor([1.0,2.0,3.0], device="mps"); x+=1; torch.mps.synchronize(); print(x.cpu())'

will cause the same error without |
|
Sorry, I forgot to update the nanobind version used by the example extension. It was still building against the PyPI 2.12.0 release while MLX itself uses the pinned nanobind commit. I updated the example extension requirements to use the same commit, so the macOS CI failure should be fixed now |
zcbenz left a comment
This PR is good to merge, but I think we should wait for a new nanobind release. I will ask for one.
} else if (type == nb::dtype<std::complex<double>>()) {
  return nd_array_to_mlx_contiguous<mx::complex128_t>(
      nd_array, shape, dtype.value_or(mx::complex64));
  return f.template operator()<mx::complex128_t>(mx::complex64);
I restored the complex128 -> complex64 mapping because NumPy creates complex arrays as complex128 by default, and the existing NumPy import path used to accept that and cast it to MLX complex64.
For Metal DLPack, I think keeping this mapping should still be OK in practice, since most frameworks do not support creating complex128 Metal arrays. If such an array is passed in, the itemsize check will still reject layouts that do not match the MLX dtype size.
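The interaction between the dtype mapping and the itemsize check can be sketched in a few lines. This is an illustrative Python model, not the PR's C++ code: the table `IMPORT_MAP` and the function `resolve_import` are hypothetical names, and the table covers only a few dtypes. It assumes the copying path may narrow complex128 to complex64, while a zero-copy path must reject any source whose element size differs from the target's.

```python
# Hypothetical table: source dtype name -> (MLX dtype name, MLX itemsize in bytes).
IMPORT_MAP = {
    "float32": ("float32", 4),
    "complex64": ("complex64", 8),
    "complex128": ("complex64", 8),  # NumPy's default complex type is narrowed
}


def resolve_import(src_dtype: str, src_itemsize: int, zero_copy: bool) -> str:
    """Pick the target dtype, rejecting zero-copy imports with a size mismatch."""
    target, target_itemsize = IMPORT_MAP[src_dtype]
    if zero_copy and src_itemsize != target_itemsize:
        # Sharing the buffer would reinterpret 16-byte elements as 8-byte ones.
        raise ValueError(
            f"cannot share {src_dtype} ({src_itemsize} B) as {target} "
            f"({target_itemsize} B) without a copy"
        )
    return target
```

In this model, a complex128 NumPy array on the copying path resolves to complex64, while the same dtype arriving on a zero-copy path raises instead of silently aliasing mismatched storage.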
Proposed changes
This draft adds zero-copy Metal DLPack sharing for MLX arrays and PyTorch MPS tensors.
This PR builds on the merged DLPack import PR #3495 and requires nanobind support.
The main changes are:
byte_offset.
- mx.from_dlpack(..., copy=...) controls for Metal DLPack inputs.
- mx.array(...) zero-copy for Metal DLPack inputs unless an explicit different dtype is requested.

The shared lifetime is tied to the exported or imported buffer. Synchronization remains explicit: PyTorch writes require torch.mps.synchronize() before MLX reads, and MLX writes require mx.eval(...) before PyTorch reads.

For MLX arrays exported to PyTorch, later MLX updates may rebind the MLX array to a new buffer while the PyTorch tensor continues to reference the exported buffer.
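The rebinding behaviour can be illustrated without a GPU. The sketch below is only an analogy built on the standard library: `Producer` is a hypothetical stand-in for an array object, and a `memoryview` plays the role of the consumer's zero-copy tensor. It does not call MLX or PyTorch.

```python
class Producer:
    """Stand-in for an array whose storage can be swapped out by an update."""

    def __init__(self, data: bytes):
        self.storage = bytearray(data)

    def export(self) -> memoryview:
        # Zero-copy: the view aliases the current storage.
        return memoryview(self.storage)

    def update(self, data: bytes) -> None:
        # Rebinds the producer to a new buffer; old views are unaffected.
        self.storage = bytearray(data)


def demo() -> tuple:
    p = Producer(b"\x01\x02\x03")
    view = p.export()
    p.storage[0] = 9            # in-place write: visible through the view
    seen_in_place = view[0]
    p.update(b"\x07\x07\x07")   # rebinding: the view keeps the old buffer
    return seen_in_place, bytes(view), bytes(p.storage)
```

Here the in-place write shows up through the exported view, but after `update` the view still holds the original buffer while the producer reads from the new one, which mirrors the caveat in the description.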
Checklist
Put an x in the boxes that apply.
pre-commit run --all-files to format my code / installed pre-commit prior to committing changes